#!/usr/bin/env python3
"""Build Paper B (Feldstein-Horioka) SSRN-ready .docx from paper.md + table files.

Uses python-docx for proper formatting: title page, abstract, rendered tables,
inline math, and appended regression output tables.
"""

import re
from pathlib import Path
from docx import Document
from docx.shared import Pt, Cm, Inches, RGBColor
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.table import WD_TABLE_ALIGNMENT
from docx.oxml.ns import nsdecls
from docx.oxml import parse_xml

# Directory layout, all resolved relative to this script's own location:
# the markdown source lives in ``paper/`` beside the script, while tables and
# figures are read from the ``output/`` tree one directory level up.
PROJECT_DIR = Path(__file__).resolve().parent
PAPER_DIR = PROJECT_DIR / "paper"
_OUTPUT_DIR = PROJECT_DIR.parent / "output"
TABLE_DIR = _OUTPUT_DIR / "tables"
FIGURE_DIR = _OUTPUT_DIR / "figures"

# Figure label as it appears in the markdown text -> image file to embed.
FIGURES = {"Figure 1": FIGURE_DIR / "implied_retention_by_z1.png"}

# Markdown files (inside TABLE_DIR) rendered into the appendix, in this order.
APPENDIX_TABLES = [
    "kaopen_prediction.md",
    "phase5_fh_long_diff.md",
    "phase5_fh_income_robustness.md",
]


# ── YAML parsing ────────────────────────────────────────────────────────

def parse_yaml_front_matter(text):
    """Split *text* into front-matter metadata and the remaining body.

    This is a deliberately small regex-based reader, not a YAML parser. It
    recognises a leading block fenced by ``---`` lines and extracts:

    * the single-line scalars ``title``, ``author``, ``date``, ``version``,
      ``keywords`` and ``jel`` (one optional pair of surrounding double
      quotes is removed);
    * ``abstract`` written as a literal block (``abstract: |``), with the
      indentation of continuation lines removed.

    Args:
        text: Full contents of the markdown file.

    Returns:
        ``(meta, body)`` where *meta* is a dict holding only the keys that
        were found and *body* is everything after the closing fence. When
        no front matter is present the result is ``({}, text)``.
    """
    # The closing fence may carry trailing blanks and may be the last line
    # of the file without a final newline.
    m = re.match(r'^---\n(.*?)\n---[ \t]*(?:\n|\Z)', text, re.DOTALL)
    if not m:
        return {}, text
    yaml_block = m.group(1)
    meta = {}

    def read_scalar(key):
        # Only blanks/tabs may separate the key from its value: with
        # re.MULTILINE a plain ``\s*`` would also swallow the newline, so an
        # empty ``key:`` line would capture the *following* line as its value.
        found = re.search(rf'^{key}:[ \t]*"?(.*?)"?[ \t]*$', yaml_block,
                          re.MULTILINE)
        if found:
            meta[key] = found.group(1).strip('"')

    for key in ('title', 'author', 'date', 'version'):
        read_scalar(key)

    # The abstract runs until the next top-level ``key:`` line or the end of
    # the block.
    am = re.search(r'abstract:\s*\|\s*\n(.*?)(?=\n\w+:|$)', yaml_block,
                   re.DOTALL)
    if am:
        meta['abstract'] = re.sub(r'\n\s{2,}', '\n', am.group(1)).strip()

    for key in ('keywords', 'jel'):
        read_scalar(key)

    return meta, text[m.end():]


# ── Rich text helpers ───────────────────────────────────────────────────

def add_formatted_run(paragraph, text, bold=False, italic=False, size=None,
                      color=None, superscript=False):
    """Append *text* to *paragraph* as a new run and return that run.

    ``bold`` and ``italic`` are always written to the run (so ``False`` is
    stored explicitly). ``size`` is a point size, ``color`` an ``(r, g, b)``
    triple; both are applied only when truthy, as is ``superscript``.
    """
    new_run = paragraph.add_run(text)
    new_run.bold, new_run.italic = bold, italic

    if size or color or superscript:
        run_font = new_run.font
        if size:
            run_font.size = Pt(size)
        if color:
            run_font.color.rgb = RGBColor(*color)
        if superscript:
            run_font.superscript = True

    return new_run


def subscript_char(c):
    """Return the Unicode subscript glyph for *c*.

    Digits 0-9 and the letters i, j, k, t and n have a subscript form;
    anything else comes back unchanged behind a literal underscore.
    """
    lookup = {str(d): chr(0x2080 + d) for d in range(10)}
    lookup.update({'i': '\u1d62', 'j': '\u2c7c', 'k': '\u2096',
                   't': '\u209c', 'n': '\u2099'})
    if c in lookup:
        return lookup[c]
    return '_' + c


def superscript_char(c):
    """Return the Unicode superscript glyph for *c*.

    Only the digits 0-9 have a superscript form here; any other input comes
    back unchanged behind a literal caret.
    """
    glyphs = '\u2070\u00b9\u00b2\u00b3\u2074\u2075\u2076\u2077\u2078\u2079'
    digit_map = dict(zip('0123456789', glyphs))
    try:
        return digit_map[c]
    except KeyError:
        return '^' + c


# LaTeX control words (without the leading backslash) -> plaintext / Unicode.
_LATEX_SYMBOLS = {
    'ln': 'ln', 'log': 'log',
    'alpha': '\u03b1', 'beta': '\u03b2', 'gamma': '\u03b3',
    'delta': '\u03b4', 'theta': '\u03b8', 'phi': '\u03c6',
    'rho': '\u03c1',
    'Delta': '\u0394', 'sum': '\u2211',
    'times': '\u00d7', 'cdot': '\u00b7',
    'neq': '\u2260', 'approx': '\u2248',
    'leq': '\u2264', 'geq': '\u2265',
    'to': '\u2192', 'rightarrow': '\u2192',
    'in': '\u2208', 'ldots': '\u2026',
}

# Combining circumflex: must FOLLOW the character it decorates.
_COMBINING_HAT = '\u0302'


def _render_script_group(match, convert_char, marker):
    """Render a braced ``_{...}`` / ``^{...}`` group character by character.

    The group is converted only when every character has a Unicode form
    (detected by *convert_char* not falling back to ``marker + char``).
    Otherwise the original text is returned so the later single-character
    pass treats it exactly as before.
    """
    content = match.group(1)
    rendered = [convert_char(ch) for ch in content]
    if content and all(r != marker + ch for r, ch in zip(rendered, content)):
        return ''.join(rendered)
    return match.group(0)


def _render_hat_group(match):
    """Place the combining circumflex on the first character of the group."""
    content = match.group(1).strip()
    return content[:1] + _COMBINING_HAT + content[1:]


def convert_math(math_text):
    """Convert LaTeX math to Unicode plaintext.

    The conversion is a best-effort textual rewrite, applied in this order:

    1. Known control words (Greek letters, operators, ``\\ln``/``\\log``) are
       replaced in a single pass over whole command names, so a longer
       command that merely starts with a known one (``\\infty`` vs ``\\in``)
       is left untouched instead of being half-converted.
    2. ``\\text{...}`` is unwrapped to its contents, ``\\mathbb{1}`` becomes
       ``1`` and the thin space ``\\,`` becomes a space.
    3. ``\\frac{a}{b}`` is flattened to ``a/b``.
    4. Braced sub/superscripts such as ``_{it}`` are converted character by
       character when every character has a Unicode form.
    5. ``\\hat{x}`` / ``\\hat x`` become ``x`` followed by a combining
       circumflex (for a multi-character group the accent goes on the first
       character).
    6. Remaining braces are dropped and single-character ``_x`` / ``^x``
       scripts are converted.

    Unknown commands pass through unchanged.

    Args:
        math_text: LaTeX source without the surrounding ``$`` delimiters.

    Returns:
        The plaintext rendering as a string.
    """
    t = math_text

    # 1. Whole-command lookup; unknown commands are kept verbatim.
    t = re.sub(r'\\([A-Za-z]+)',
               lambda m: _LATEX_SYMBOLS.get(m.group(1), m.group(0)), t)

    # 2. Wrappers and spacing.
    t = re.sub(r'\\text\{([^{}]*)\}', lambda m: m.group(1), t)
    t = t.replace(r'\mathbb{1}', '1').replace(r'\,', ' ')

    # 3. Fractions: drop the opener now; the '}{' between numerator and
    #    denominator becomes '/' further down.
    t = t.replace(r'\frac{', '')

    # 4. Braced scripts, handled before braces are stripped so that every
    #    character of the group is converted, not just the first one.
    t = re.sub(r'_\{([^{}]*)\}',
               lambda m: _render_script_group(m, subscript_char, '_'), t)
    t = re.sub(r'\^\{([^{}]*)\}',
               lambda m: _render_script_group(m, superscript_char, '^'), t)

    # 5. Hats. Greek letters are already Unicode at this point, so the
    #    accent can be attached after the base character.
    t = re.sub(r'\\hat\s*\{([^{}]*)\}', _render_hat_group, t)
    t = re.sub(r'\\hat(?![A-Za-z])\s*([^\s{}\\])',
               lambda m: m.group(1) + _COMBINING_HAT, t)
    # Anything still unmatched (e.g. nested braces) keeps the old fallback.
    t = t.replace(r'\hat', _COMBINING_HAT)

    # 6. Flatten what is left.
    t = t.replace('}{', '/')
    t = t.replace('{', '').replace('}', '')
    t = re.sub(r'_(\w)', lambda m: subscript_char(m.group(1)), t)
    t = re.sub(r'\^(\w)', lambda m: superscript_char(m.group(1)), t)
    return t


def _emphasis_pattern(stars):
    """Compile a pattern for emphasis delimited by exactly *stars* asterisks.

    An opening run must be followed by a non-space character and a closing
    run must be preceded by one. A run that sits between a word character
    and whitespace/punctuation (``0.45***``, ``*** p<0.01``) therefore never
    opens emphasis and is left as literal text.
    """
    delim = r'\*' * stars
    opener = rf'(?:(?<![\w*]){delim}(?=[^\s*])|(?<!\*){delim}(?=\w))'
    closer = rf'(?:(?<=[^\s*]){delim}(?![\w*])|(?<=\w){delim}(?!\*))'
    return re.compile(opener + r'(.+?)' + closer)


# Inline constructs in priority order: when two patterns match at the same
# position, the one listed first wins.
_INLINE_PATTERNS = [
    (_emphasis_pattern(3), 'bold_italic'),
    (_emphasis_pattern(2), 'bold'),
    (_emphasis_pattern(1), 'italic'),
    (re.compile(r'\$\$(.*?)\$\$'), 'display_math'),
    (re.compile(r'\$(.*?)\$'), 'math'),
    (re.compile(r'\[(.*?)\]'), 'bracket'),
    (re.compile(r'@(\w+\d{4}\w?)'), 'citation'),
    (re.compile(r'---'), 'emdash'),
]


def add_rich_text(paragraph, text, base_size=11, base_italic=False):
    """Parse inline markdown (bold, italic, math, citations, emdash).

    *text* is scanned left to right; at each step the construct that starts
    earliest is rendered as its own run and the plain text before it is
    emitted with the base formatting. Nested constructs are not interpreted:
    the content of a match is emitted as-is.

    Args:
        paragraph: Target paragraph; runs are appended to it.
        text: Markdown source for the paragraph.
        base_size: Font size in points for every run.
        base_italic: Whether surrounding text is italic. This also applies
            to bold, bracket, citation and em-dash runs so that they do not
            turn upright in the middle of an italic note.
    """
    pos = 0
    end = len(text)
    while pos < end:
        # Search from ``pos`` instead of slicing so look-behinds still see
        # the character preceding a candidate delimiter.
        best = None
        best_type = None
        for pattern, ptype in _INLINE_PATTERNS:
            m = pattern.search(text, pos)
            if m and (best is None or m.start() < best.start()):
                best, best_type = m, ptype

        if best is None:
            add_formatted_run(paragraph, text[pos:], size=base_size,
                              italic=base_italic)
            break

        if best.start() > pos:
            add_formatted_run(paragraph, text[pos:best.start()],
                              size=base_size, italic=base_italic)

        content = best.group(1) if best.re.groups else ''
        if best_type == 'bold_italic':
            add_formatted_run(paragraph, content, bold=True, italic=True,
                              size=base_size)
        elif best_type == 'bold':
            add_formatted_run(paragraph, content, bold=True,
                              italic=base_italic, size=base_size)
        elif best_type == 'italic':
            add_formatted_run(paragraph, content, italic=True,
                              size=base_size)
        elif best_type in ('math', 'display_math'):
            add_formatted_run(paragraph, convert_math(content), italic=True,
                              size=base_size)
        elif best_type == 'citation':
            # ``@key2020`` is rendered as the bare key.
            add_formatted_run(paragraph, content, italic=base_italic,
                              size=base_size)
        elif best_type == 'bracket':
            # Keep the brackets but drop the ``@`` of any citation keys.
            content = re.sub(r'@(\w+)', r'\1', content)
            add_formatted_run(paragraph, f'[{content}]', italic=base_italic,
                              size=base_size)
        elif best_type == 'emdash':
            add_formatted_run(paragraph, '\u2014', italic=base_italic,
                              size=base_size)

        pos = best.end()

# ── Table rendering ─────────────────────────────────────────────────────

def parse_markdown_table(lines):
    """Split the lines of a pipe table into a header and data rows.

    The first line is the header. The second line is skipped only when it
    is a delimiter row (made up solely of ``|``, ``-``, ``:``, ``=``, ``+``
    and blanks, with at least one ``-`` or ``=``) or blank; if it holds
    data it is kept as the first row instead of being silently discarded.
    Blank lines are ignored.

    Args:
        lines: Raw text lines of the table, header first.

    Returns:
        ``(header, rows)`` as a list of cell strings and a list of such
        lists, or ``(None, None)`` when fewer than two lines are given.
    """
    if len(lines) < 2:
        return None, None

    def parse_row(line):
        cells = [c.strip() for c in line.split('|')]
        # A leading/trailing pipe produces an empty first/last cell.
        if cells and cells[0] == '':
            cells = cells[1:]
        if cells and cells[-1] == '':
            cells = cells[:-1]
        return cells

    def is_delimiter(line):
        stripped = line.strip()
        return (bool(stripped)
                and all(ch in '|-:=+ \t' for ch in stripped)
                and any(ch in '-=' for ch in stripped))

    header = parse_row(lines[0])
    first_data = 2 if is_delimiter(lines[1]) else 1
    rows = [parse_row(l) for l in lines[first_data:] if l.strip()]
    return header, rows


def add_table(doc, header, rows, title=None, notes=None):
    """Append a formatted table, with optional title and notes, to *doc*.

    The header row is bold, centred and shaded; every second data row gets
    a light grey background. Cell text goes through ``add_rich_text`` at
    9 pt. Data rows shorter than the header leave the missing cells empty;
    extra cells beyond the header width are ignored.

    Args:
        doc: Document to append to.
        header: List of header cell strings; its length fixes the column
            count.
        rows: List of rows, each a list of cell strings.
        title: Optional bold caption placed above the table.
        notes: Optional note text placed below the table in 8 pt italic.

    Returns:
        The table object that was added.
    """
    def shade(cell, fill):
        # python-docx has no public API for cell shading, so the <w:shd>
        # element is attached directly to the cell properties.
        shading = parse_xml(f'<w:shd {nsdecls("w")} w:fill="{fill}"/>')
        cell._tc.get_or_add_tcPr().append(shading)

    if title:
        p = doc.add_paragraph()
        p.alignment = WD_ALIGN_PARAGRAPH.LEFT
        # Spacing lives on paragraph_format; assigning it on the paragraph
        # object itself only creates an unused attribute.
        p.paragraph_format.space_before = Pt(12)
        add_formatted_run(p, title, bold=True, size=11)

    ncols = len(header)
    nrows = len(rows) + 1
    table = doc.add_table(rows=nrows, cols=ncols)
    table.alignment = WD_TABLE_ALIGNMENT.CENTER
    table.style = 'Table Grid'
    table.autofit = True

    header_cells = table.rows[0].cells
    for j, cell_text in enumerate(header):
        cell = header_cells[j]
        cell.text = ''
        p = cell.paragraphs[0]
        p.alignment = WD_ALIGN_PARAGRAPH.CENTER
        add_rich_text(p, cell_text, base_size=9, base_italic=False)
        for run in p.runs:
            run.bold = True
            run.font.size = Pt(9)
        shade(cell, 'D9E2F3')

    for i, row_data in enumerate(rows):
        row_cells = table.rows[i + 1].cells
        for j in range(ncols):
            cell = row_cells[j]
            cell.text = ''
            p = cell.paragraphs[0]
            p.alignment = (WD_ALIGN_PARAGRAPH.LEFT if j == 0
                           else WD_ALIGN_PARAGRAPH.CENTER)
            if j < len(row_data):
                add_rich_text(p, row_data[j], base_size=9)
            # Shade the whole row, including cells a short row left empty.
            if i % 2 == 1:
                shade(cell, 'F2F2F2')

    if notes:
        p = doc.add_paragraph()
        p.paragraph_format.space_before = Pt(2)
        p.paragraph_format.space_after = Pt(12)
        add_rich_text(p, notes, base_size=8, base_italic=True)

    return table


# ── Body processing ─────────────────────────────────────────────────────

def process_body(doc, body):
    """Process the markdown body into the docx document.

    The body is read line by line and each construct is appended to *doc*:

    * blank lines and footnote definitions (``[^name]...``) are skipped;
    * ``#``/``##``/``###`` lines become headings of level 1-3;
    * ``$$ ... $$`` blocks (single- or multi-line) become centred equations;
    * a ``**Table ...`` line that is followed by a pipe table becomes that
      table's title, and a ``*Notes:`` paragraph after the table becomes
      its note;
    * a run of lines starting with ``|`` becomes an untitled table;
    * ``- item`` lines become bullet paragraphs;
    * everything else is gathered into paragraphs, with inline footnote
      references removed.

    Args:
        doc: Document to append to.
        body: Markdown text without front matter.
    """
    lines = body.split('\n')
    total = len(lines)
    i = 0
    while i < total:
        line = lines[i]
        stripped = line.strip()

        if not stripped:
            i += 1
            continue

        # Footnotes — skip (content already in text)
        if re.match(r'^\[\^', line):
            i += 1
            continue

        # Headings
        heading_match = re.match(r'^(#{1,3})\s+(.*)', line)
        if heading_match:
            level = len(heading_match.group(1))
            heading_text = heading_match.group(2).strip()
            h = doc.add_heading(level=level)
            add_rich_text(h, heading_text,
                          base_size={1: 16, 2: 13, 3: 11}[level])
            i += 1
            continue

        # Display equations ($$...$$)
        if stripped.startswith('$$'):
            eq_lines = [line]
            # A lone '$$' or a line without a closing '$$' opens a
            # multi-line equation; collect up to the closing line.
            if not stripped.endswith('$$') or stripped == '$$':
                i += 1
                while i < total:
                    eq_lines.append(lines[i])
                    if lines[i].strip().endswith('$$'):
                        break
                    i += 1
            eq_text = ' '.join(eq_lines).replace('$$', '').strip()
            p = doc.add_paragraph()
            p.alignment = WD_ALIGN_PARAGRAPH.CENTER
            # Spacing must be set on paragraph_format to take effect.
            p.paragraph_format.space_before = Pt(6)
            p.paragraph_format.space_after = Pt(6)
            add_formatted_run(p, convert_math(eq_text), italic=True, size=11)
            i += 1
            continue

        # Table title (bold line starting with **Table). Only treated as a
        # title when a table really follows; otherwise the line is ordinary
        # text (e.g. "**Table 2** reports ...") and must not be dropped.
        if stripped.startswith('**Table'):
            table_start = i + 1
            while table_start < total and not lines[table_start].strip():
                table_start += 1
            if table_start < total and '|' in lines[table_start]:
                table_title = stripped.strip('*')
                i = table_start
                table_lines = []
                while i < total and '|' in lines[i]:
                    table_lines.append(lines[i])
                    i += 1
                while i < total and not lines[i].strip():
                    i += 1
                notes = None
                if i < total and lines[i].strip().startswith('*Notes:'):
                    note_parts = [lines[i].strip()]
                    i += 1
                    # The note continues until a blank line or the start
                    # of another construct.
                    while i < total:
                        cont = lines[i].strip()
                        if not cont or cont.startswith(
                                ('#', '**', '|', '*Notes:')):
                            break
                        note_parts.append(cont)
                        i += 1
                    # Strip the emphasis markers after joining so the
                    # closing '*' of a multi-line note goes away as well.
                    notes = ' '.join(note_parts).strip('*')
                header, rows = parse_markdown_table(table_lines)
                if header and rows:
                    add_table(doc, header, rows, title=table_title,
                              notes=notes)
                continue

        # Inline table
        if stripped.startswith('|'):
            table_lines = []
            while i < total and '|' in lines[i]:
                table_lines.append(lines[i])
                i += 1
            if table_lines:
                header, rows = parse_markdown_table(table_lines)
                if header and rows:
                    add_table(doc, header, rows)
            continue

        # Bullet points
        bullet_match = re.match(r'^(\s*)-\s+(.*)', line)
        if bullet_match:
            bullet_text = bullet_match.group(2)
            p = doc.add_paragraph(style='List Bullet')
            add_rich_text(p, bullet_text, base_size=11)
            i += 1
            continue

        # Regular paragraph: gather lines until a blank line or the start
        # of any other construct.
        para_lines = [line]
        i += 1
        while i < total:
            next_line = lines[i]
            next_stripped = next_line.strip()
            if not next_stripped:
                break
            if re.match(r'^#{1,3}\s', next_line):
                break
            if next_stripped.startswith(('|', '**Table', '$$')):
                break
            if re.match(r'^\s*-\s+', next_line):
                break
            if re.match(r'^\[\^', next_line):
                break
            para_lines.append(next_line)
            i += 1

        para_text = ' '.join(l.strip() for l in para_lines)
        # Strip footnote references like [^income-note]
        para_text = re.sub(r'\[\^[^\]]+\]', '', para_text)
        if para_text.strip():
            p = doc.add_paragraph()
            add_rich_text(p, para_text, base_size=11)


# ── Main ────────────────────────────────────────────────────────────────

def build():
    """Render ``paper.md`` in PAPER_DIR to ``paper.docx`` in the same folder.

    Steps: read the markdown and its front matter, set up page margins and
    fonts, write the title page (title, author, date, optional version,
    abstract, keywords, JEL codes), render the body, embed the configured
    figures, append the appendix tables that exist on disk, and save.
    Missing figures or appendix files are reported on stdout and skipped.
    """
    md_path = PAPER_DIR / "paper.md"
    docx_path = PAPER_DIR / "paper.docx"

    text = md_path.read_text(encoding='utf-8')
    meta, body = parse_yaml_front_matter(text)

    doc = Document()

    def front_paragraph(space_after, alignment=None):
        # Paragraph spacing has to be set through paragraph_format;
        # assigning ``space_after`` on the paragraph object has no effect.
        para = doc.add_paragraph()
        if alignment is not None:
            para.alignment = alignment
        para.paragraph_format.space_after = Pt(space_after)
        return para

    # Page setup
    for section in doc.sections:
        section.top_margin = Cm(2.54)
        section.bottom_margin = Cm(2.54)
        section.left_margin = Cm(2.54)
        section.right_margin = Cm(2.54)

    # Default font
    style = doc.styles['Normal']
    font = style.font
    font.name = 'Times New Roman'
    font.size = Pt(11)
    style.paragraph_format.space_after = Pt(6)
    style.paragraph_format.line_spacing = 1.15

    # Heading styles: (font size, space before, space after) per level
    heading_layout = {1: (16, 24, 12), 2: (13, 18, 8), 3: (11, 12, 6)}
    for lvl, (size, before, after) in heading_layout.items():
        hs = doc.styles[f'Heading {lvl}']
        hs.font.name = 'Times New Roman'
        hs.font.color.rgb = RGBColor(0, 0, 0)
        hs.font.size = Pt(size)
        hs.paragraph_format.space_before = Pt(before)
        hs.paragraph_format.space_after = Pt(after)

    # ── Title page ──────────────────────────────────────────────────
    centered = WD_ALIGN_PARAGRAPH.CENTER

    p = front_paragraph(6, centered)
    add_formatted_run(p, meta.get('title', 'Untitled'), bold=True, size=18)

    p = front_paragraph(4, centered)
    add_formatted_run(p, meta.get('author', ''), size=12)

    p = front_paragraph(4, centered)
    add_formatted_run(p, meta.get('date', ''), size=12)

    if 'version' in meta:
        p = front_paragraph(18, centered)
        add_formatted_run(p, f"Version: {meta['version']}",
                          size=10, italic=True)

    # Abstract
    if 'abstract' in meta:
        p = front_paragraph(6, WD_ALIGN_PARAGRAPH.LEFT)
        add_formatted_run(p, 'Abstract', bold=True, size=11)

        p = front_paragraph(12)
        p.paragraph_format.left_indent = Cm(1.27)
        p.paragraph_format.right_indent = Cm(1.27)
        add_rich_text(p, meta['abstract'], base_size=10)

    # Keywords and JEL
    if 'keywords' in meta:
        p = front_paragraph(2)
        add_formatted_run(p, 'Keywords: ', bold=True, size=10)
        add_formatted_run(p, meta['keywords'], size=10)
    if 'jel' in meta:
        p = front_paragraph(18)
        add_formatted_run(p, 'JEL Classification: ', bold=True, size=10)
        add_formatted_run(p, meta['jel'], size=10)

    doc.add_page_break()

    # ── Paper body ──────────────────────────────────────────────────
    process_body(doc, body)

    # ── Embed figures in the paragraphs that reference them ──────────
    for fig_label, fig_path in FIGURES.items():
        if not fig_path.exists():
            print(f"  WARNING: Figure not found: {fig_path}")
            continue
        # Use the first paragraph that mentions the label together with
        # the word "plots" (e.g. "Figure 1 plots ...").
        for para in doc.paragraphs:
            if fig_label in para.text and 'plots' in para.text.lower():
                # The picture goes into the same paragraph, after a line
                # break, so it sits directly below the referencing text.
                run = para.add_run()
                run.add_break()
                run.add_picture(str(fig_path), width=Inches(5.5))
                break

    # ── Appendix table files ────────────────────────────────────────
    table_files = []
    for name in APPENDIX_TABLES:
        tf = TABLE_DIR / name
        if tf.exists():
            table_files.append(tf)
        else:
            print(f"  WARNING: Table not found: {tf}")

    if table_files:
        doc.add_page_break()
        h = doc.add_heading(level=1)
        add_rich_text(h, 'Online Appendix: Regression Output Tables',
                      base_size=16)
        # Each appendix file starts on its own page.
        for tf in table_files:
            doc.add_page_break()
            appendix_text = tf.read_text(encoding='utf-8')
            process_body(doc, appendix_text)

    # Save
    doc.save(str(docx_path))
    print(f"Saved: {docx_path}")
    print(f"Tables: {len(doc.tables)} (all rendered inline)")
    print(f"Appendix files: {len(table_files)}")


if __name__ == '__main__':
    build()
